import cv2
import os
from tqdm import tqdm
import torch
import numpy as np
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision.transforms import ToTensor
# from diffusers import UNet2DConditionModel, StableDiffusionControlNetPipeline
from PIL import Image
from omegaconf import OmegaConf
from accelerate.utils import set_seed
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

from model.cldm_depth import ControlLDM
from model.gaussian_diffusion import Diffusion
from pipeline.pipeline import DepthPipeline
from utils.common import instantiate_from_config, load_file_from_url, count_vram_usage
from utils.flow import backward_warp, forward_backward_consistency_check

from pipeline.pipeline import pad_to_multiples_of



parser = ArgumentParser(description="Depth-guided restoration inference with a ControlLDM + fidelity encoder.")
parser.add_argument("--input_lq", required=True, type=str,
                    help="directory of low-quality input images")
parser.add_argument("--input_depth", required=True, type=str,
                    help="directory of depth maps, named <input image name>.png")
parser.add_argument("--output", default='results', type=str)
parser.add_argument("--tiled", action='store_true', help='crop into specific patches for inference, which can help reduce GPU memory')
parser.add_argument("--tile_size", default=512, type=int)
parser.add_argument("--tile_stride", default=256, type=int)
parser.add_argument("--ckpt", required=True, type=str, help="ControlNet checkpoint to load into the ControlLDM")
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
parser.add_argument("--seed", type=int, default=231)
parser.add_argument("--fidelity", type=str, default='/media/ubuntu/data/HZW/Diff-RDRF/pretrained_models/vae_gate_0300000.pt')
parser.add_argument("--fidelity_input", type=str, default='lq', choices=['lq', 'stage1', 'both', 'None'])
parser.add_argument("--config", type=str, default="configs/train/train_depth.yaml",
                    help="OmegaConf YAML providing model.cldm / model.fidelity_encoder / model.diffusion")
parser.add_argument("--sd_ckpt", type=str, default='/media/ubuntu/data/HZW/Diff-RDRF/v2-1_512-ema-pruned.ckpt',
                    help="pretrained Stable Diffusion v2.1 checkpoint (must contain a 'state_dict' entry)")
args = parser.parse_args()
if not args.tiled:
    # Only the tiled inference path is implemented; fail before spending time loading models.
    parser.error("non-tiled inference is not implemented; re-run with --tiled")

### make result dir
os.makedirs(args.output, exist_ok=True)

# Parse the model config once and reuse it for every sub-model below.
cfg = OmegaConf.load(args.config)

# NOTE: torch.load unpickles its input -- only point these options at trusted checkpoints.
# All checkpoints are mapped to CPU first so loading works on machines without the saving device;
# each module is moved to --device afterwards.

### load unet, vae, clip from the pretrained SD weights
cldm: ControlLDM = instantiate_from_config(cfg.model.cldm)
sd = torch.load(args.sd_ckpt, map_location="cpu")["state_dict"]
unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")

### load controlnet
control_sd = torch.load(args.ckpt, map_location="cpu")
cldm.load_controlnet_from_ckpt(control_sd)
print("strictly load controlnet weight")
cldm.eval().to(args.device)

### load fidelity encoder (honours --device instead of always using CUDA)
fidelity_encoder = instantiate_from_config(cfg.model.fidelity_encoder)
fidelity_encoder_sd = torch.load(args.fidelity, map_location="cpu")
fidelity_encoder.load_state_dict(fidelity_encoder_sd, strict=True)
fidelity_encoder = fidelity_encoder.to(args.device)
fidelity_encoder.eval()

### load diffusion
diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion)
diffusion.to(args.device)

pipe = DepthPipeline(cldm=cldm, diffusion=diffusion, fidelity_encoder=fidelity_encoder, device=args.device)
to_tensor = ToTensor()


def _list_inputs(lq_root, depth_root):
    """Return ``(name, lq_path, depth_path)`` triples for every image to restore.

    ``lq_root`` may be a directory (every regular file in it is processed, in
    sorted order) or a single image file. The depth map for an input named
    ``<name>.<ext>`` is ``<depth_root>/<name>.png``; when ``lq_root`` is a single
    file, ``depth_root`` may instead point directly at its depth image.

    Raises:
        FileNotFoundError: if ``lq_root`` is neither a directory nor a file.
    """
    if os.path.isdir(lq_root):
        triples = []
        for fname in sorted(os.listdir(lq_root)):
            lq_path = os.path.join(lq_root, fname)
            if not os.path.isfile(lq_path):
                continue  # skip sub-directories and other non-regular entries
            # splitext only strips the final extension, so "a.1.png" and "a.2.png"
            # keep distinct names (and distinct depth maps / output files).
            name = os.path.splitext(fname)[0]
            triples.append((name, lq_path, os.path.join(depth_root, name + '.png')))
        return triples
    if os.path.isfile(lq_root):
        name = os.path.splitext(os.path.basename(lq_root))[0]
        if os.path.isdir(depth_root):
            depth_path = os.path.join(depth_root, name + '.png')
        else:
            depth_path = depth_root
        return [(name, lq_root, depth_path)]
    raise FileNotFoundError(f"--input_lq is neither a directory nor a file: {lq_root}")


def _load_rgb_tensor(path, device):
    """Read ``path`` as RGB and return a batched ``(1, 3, H, W)`` float tensor in [0, 1] on ``device``."""
    with Image.open(path) as im:  # close the file handle once the pixels are converted
        return to_tensor(im.convert('RGB')).unsqueeze(dim=0).to(device)


def _fidelity_input_for(mode, lq_pad, depth_pad):
    """Map the ``--fidelity_input`` choice to the ``fidelity_input`` argument of ``pipe.run``."""
    if mode == 'lq':
        return lq_pad
    if mode == 'stage1':
        return depth_pad
    if mode == 'both':
        return [lq_pad, depth_pad]
    return None  # 'None': run the pipeline without a fidelity input


if not args.tiled:
    # Only the tiled path is implemented; without this guard `out` would be unbound below.
    raise NotImplementedError("non-tiled inference is not implemented; re-run with --tiled")

for img_name, img_path, depth_path in tqdm(_list_inputs(args.input_lq, args.input_depth), desc='input', leave=False):
    img1 = _load_rgb_tensor(img_path, args.device)
    depth = _load_rgb_tensor(depth_path, args.device)
    _, _, h, w = img1.shape

    set_seed(args.seed)  # re-seed per image so each result is reproducible independently
    # Pad H and W up to a multiple of 8 (presumably the latent downsampling factor -- TODO confirm).
    img1_pad = pad_to_multiples_of(imgs=img1, multiple=8)
    depth_pad = pad_to_multiples_of(imgs=depth, multiple=8)
    out = pipe.run(
        lq=img1_pad,  # in [0, 1]; pipe.run converts it to [-1, 1] internally
        depth=depth_pad,
        tiled=args.tiled,
        tile_size=args.tile_size,
        tile_stride=args.tile_stride,
        fidelity_input=_fidelity_input_for(args.fidelity_input, img1_pad, depth_pad),
    )
    out = out[:, :, :h, :w]  # crop the padding back off

    # (out + 1) / 2 assumes the pipeline output is in [-1, 1]; save_image expects [0, 1].
    save_image((out + 1) / 2, os.path.join(args.output, f"{img_name}_o.png"))